{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Self-attention does not need O(n^2) memory.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "brHNYfaZ4KU3"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Self-attention does not need O(n^2) memory\n",
        "\n",
        "This colab accompanies our [paper](https://arxiv.org/abs/2112.05682) on a memory-efficient implementation of (self-)attention. It contains the standard attention implementation, our memory-efficient attention implementation, and evaluation code to determine and compare their runtime performance.\n",
        "\n",
        "Please remember to connect to a colab runtime with a TPU. You can check whether the runtime you are using has a TPU, by hovering over the RAM and Disk indicator in the top right.\n",
        "\n",
        "In case you have questions, please send us an email to mrabe@google.com and cstaats@google.com.\n"
      ],
      "metadata": {
        "id": "DQj7efd7XZQu"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## License"
      ],
      "metadata": {
        "id": "brHNYfaZ4KU3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Copyright 2021 Google LLC.\n",
        "\n",
        "SPDX-License-Identifier: Apache-2.0\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "   you may not use this file except in compliance with the License.\n",
        "   You may obtain a copy of the License at\n",
        "\n",
        "       http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "   Unless required by applicable law or agreed to in writing, software\n",
        "   distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "   See the License for the specific language governing permissions and\n",
        "   limitations under the License."
      ],
      "metadata": {
        "id": "rxveME4k28lv"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Imports and utilities"
      ],
      "metadata": {
        "id": "p_J2imNhL00v"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OcxJTUqSnAJM"
      },
      "outputs": [],
      "source": [
        "from jax import lax\n",
        "from typing import Any, Optional\n",
        "from jax import numpy as jnp\n",
        "import numpy as np\n",
        "import time\n",
        "import jax\n",
        "import timeit\n",
        "import functools\n",
        "\n",
        "import jax.tools.colab_tpu\n",
        "try:\n",
        "  jax.tools.colab_tpu.setup_tpu()\n",
        "except KeyError:\n",
        "  print('Could not find a TPU; running this colab on CPU or GPU.')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Utility to create random data\n",
        "\n",
        "_cur_key = jax.random.PRNGKey(4)\n",
        "\n",
        "def fresh():\n",
        "  global _cur_key\n",
        "  _cur_key, result = jax.random.split(_cur_key)\n",
        "  return result\n",
        "\n",
        "num_heads = 1\n",
        "feature_dims = 64\n",
        "\n",
        "def fresh_qkv(size, dtype=jnp.bfloat16):\n",
        "  qkv_shape = (size, num_heads, feature_dims)\n",
        "  return jax.random.normal(fresh(), qkv_shape, dtype=dtype)\n",
        "\n",
        "fresh_qkv(2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8f7OE_-LnJHl",
        "outputId": "1e466f23-4c88-43e3-c005-542e1573ce7d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DeviceArray([[[-1, 0.201172, -1, 0.984375, -1.40625, 0.34375, 0.0439453,\n",
              "               0.65625, 0.679688, 0.894531, -1, 0.261719, 0.515625,\n",
              "               -0.824219, -0.192383, 0.494141, -0.0932617, 1.01562,\n",
              "               0.162109, 0.0439453, 1.19531, 0.8125, -0.3125, -1.17969,\n",
              "               0.222656, -1.10156, 1.54688, 0.283203, 0.261719,\n",
              "               0.515625, 1.54688, -0.597656, -1.65625, -0.251953,\n",
              "               -0.851562, -1.03125, -1.73438, 0.632812, 1.49219,\n",
              "               -0.398438, -0.3125, 0.679688, -0.851562, 0.103027,\n",
              "               1.70312, 0.957031, 0.322266, 0.103027, 0.921875,\n",
              "               -2.10938, 1.125, -0.0932617, -0.152344, -0.271484,\n",
              "               0.0439453, 0.241211, -0.192383, 0.0634766, 0.283203,\n",
              "               1.78906, -0.353516, 0.957031, -0.851562, 0.957031]],\n",
              "\n",
              "             [[-1.14062, -0.333984, -0.667969, -2.32812, -1.17969,\n",
              "               0.408203, -1.10156, -0.353516, 0.757812, -0.824219,\n",
              "               -0.527344, 0.609375, 0.8125, -0.769531, -1.40625,\n",
              "               -1.10156, -0.746094, -0.621094, 0.162109, -1.40625,\n",
              "               -0.796875, -0.0932617, 0.494141, 1.04688, 1.19531,\n",
              "               2.20312, -1.73438, 0.386719, 0.00488281, 0.957031,\n",
              "               0.0439453, -0.112305, -0.527344, -0.667969, 0.703125,\n",
              "               0.5625, -1.0625, 0.515625, -0.460938, -0.503906,\n",
              "               -0.910156, 0.201172, -1.40625, -1.83594, -0.232422, -1,\n",
              "               0.494141, 0.0439453, -0.972656, 0.142578, -2.89062,\n",
              "               0.386719, 0.123047, -0.375, 1.54688, 0.515625, -2.89062,\n",
              "               -1.21875, 0.984375, 0.632812, 0.103027, 0.241211,\n",
              "               -0.0539551, 0.703125]]], dtype=bfloat16)"
            ]
          },
          "metadata": {},
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Standard attention"
      ],
      "metadata": {
        "id": "v6qCuy6-L-gr"
      }
    },
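    {
      "cell_type": "markdown",
      "source": [
        "For reference, the cell below computes scaled dot-product attention independently for each head. With queries $Q \\in \\mathbb{R}^{n_q \\times d}$, keys $K \\in \\mathbb{R}^{n_k \\times d}$ and values $V \\in \\mathbb{R}^{n_k \\times d_v}$ (these symbols are notation introduced here for illustration; $d$ is `depth` in the code):\n",
        "\n",
        "$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{Q K^\\top}{\\sqrt{d}}\\right) V,$$\n",
        "\n",
        "where the softmax is taken over the key axis. The intermediate matrix $Q K^\\top$ has $n_q \\times n_k$ entries per head, which is the source of the quadratic memory cost.\n"
      ],
      "metadata": {}
    },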
    {
      "cell_type": "code",
      "source": [
        "# Based on flax/linen/attention.py\n",
        "\n",
        "def standard_attention(query, key, value, dtype=jnp.float32, precision=None):\n",
        "  depth = query.shape[-1]\n",
        "  query = query / jnp.sqrt(depth).astype(dtype)\n",
        "  # attn weight shape is (batch..., num_heads, q_length, kv_length)\n",
        "  attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key,\n",
        "                            precision=precision)\n",
        "\n",
        "  # normalize the attention weights\n",
        "  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)\n",
        "  return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,\n",
        "                    precision=precision)"
      ],
      "metadata": {
        "id": "amM6wyRPnMb9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Memory-efficient Self-Attention"
      ],
      "metadata": {
        "id": "THNUMDsyMELu"
      }
    },
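    {
      "cell_type": "markdown",
      "source": [
        "The next cell avoids materializing the full score matrix by processing keys and values in chunks. The following is a sketch of that computation as read off from the code; the symbols $s$, $m$, $w$, $u$ are notation introduced here for illustration. For one query $q_i$ and one head, with scores $s_{ij} = q_i \\cdot k_j / \\sqrt{d}$, each key chunk $c$ is summarized by\n",
        "\n",
        "$$m_c = \\max_{j \\in c} s_{ij}, \\qquad w_c = \\sum_{j \\in c} e^{s_{ij} - m_c}, \\qquad u_c = \\sum_{j \\in c} e^{s_{ij} - m_c}\\, v_j.$$\n",
        "\n",
        "With the global maximum $m = \\max_c m_c$, the chunk summaries are rescaled and combined:\n",
        "\n",
        "$$\\mathrm{out}_i = \\frac{\\sum_c e^{m_c - m}\\, u_c}{\\sum_c e^{m_c - m}\\, w_c} = \\frac{\\sum_j e^{s_{ij} - m}\\, v_j}{\\sum_j e^{s_{ij} - m}}.$$\n",
        "\n",
        "The right-hand side is the usual softmax-weighted average of the values, so the result matches standard attention up to floating-point rounding. Subtracting the maxima keeps every exponent non-positive, which avoids overflow.\n"
      ],
      "metadata": {}
    },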
    {
      "cell_type": "code",
      "source": [
        "# This cell is self-contained; the following imports suffice to run the\n",
        "# memory-efficient attention implementation.\n",
        "import functools, jax, math\n",
        "from jax import lax\n",
        "from jax import numpy as jnp\n",
        "\n",
        "\n",
        "def _query_chunk_attention(query,\n",
        "                           key,\n",
        "                           value,\n",
        "                           key_chunk_size=4096,\n",
        "                           precision=lax.Precision.HIGHEST,\n",
        "                           dtype=jnp.float32):\n",
        "  num_kv, num_heads, k_features = key.shape\n",
        "  v_features = value.shape[-1]\n",
        "  key_chunk_size = min(key_chunk_size, num_kv)\n",
        "  query = query / jnp.sqrt(k_features).astype(dtype)\n",
        "\n",
        "  @functools.partial(jax.checkpoint, prevent_cse=False)\n",
        "  def summarize_chunk(query, key, value):\n",
        "    attn_weights = jnp.einsum(\n",
        "        'qhd,khd->qhk', query, key, precision=precision).astype(dtype)\n",
        "    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)\n",
        "    max_score = jax.lax.stop_gradient(max_score)\n",
        "    exp_weights = jnp.exp(attn_weights - max_score)\n",
        "    exp_values = jnp.einsum(\n",
        "        'vhf,qhv->qhf', value, exp_weights, precision=precision).astype(dtype)\n",
        "    return (exp_values, exp_weights.sum(axis=-1),\n",
        "            max_score.reshape((query.shape[0], num_heads)))\n",
        "\n",
        "  def chunk_scanner(chunk_idx):\n",
        "    key_chunk = lax.dynamic_slice(\n",
        "        key, (chunk_idx, 0, 0),\n",
        "        slice_sizes=(key_chunk_size, num_heads, k_features))\n",
        "    value_chunk = lax.dynamic_slice(\n",
        "        value, (chunk_idx, 0, 0),\n",
        "        slice_sizes=(key_chunk_size, num_heads, v_features))\n",
        "    return summarize_chunk(query, key_chunk, value_chunk)\n",
        "\n",
        "  chunk_values, chunk_weights, chunk_max = lax.map(\n",
        "      chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))\n",
        "\n",
        "  global_max = jnp.max(chunk_max, axis=0, keepdims=True)\n",
        "  max_diffs = jnp.exp(chunk_max - global_max)\n",
        "  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)\n",
        "  chunk_weights *= max_diffs\n",
        "\n",
        "  all_values = chunk_values.sum(axis=0)\n",
        "  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)\n",
        "  return all_values / all_weights\n",
        "\n",
        "\n",
        "def mefficient_attention(query,\n",
        "                         key,\n",
        "                         value,\n",
        "                         query_chunk_size=1024,\n",
        "                         precision=jax.lax.Precision.HIGHEST,\n",
        "                         dtype=jnp.float32):\n",
        "  num_q, num_heads, q_features = query.shape\n",
        "\n",
        "  def chunk_scanner(chunk_idx, _):\n",
        "    query_chunk = lax.dynamic_slice(\n",
        "        query, (chunk_idx, 0, 0),\n",
        "        slice_sizes=(min(query_chunk_size, num_q), num_heads, q_features))\n",
        "    return (chunk_idx + query_chunk_size,\n",
        "            _query_chunk_attention(\n",
        "                query_chunk, key, value, precision=precision, dtype=dtype))\n",
        "\n",
        "  _, res = lax.scan(\n",
        "      chunk_scanner,\n",
        "      init=0,\n",
        "      xs=None,\n",
        "      length=math.ceil(num_q / query_chunk_size))\n",
        "  return res.reshape(num_q, num_heads, value.shape[-1])"
      ],
      "metadata": {
        "id": "KzlCGYqWnPoV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Evaluation"
      ],
      "metadata": {
        "id": "9AAAaJTkMI0B"
      }
    },
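    {
      "cell_type": "markdown",
      "source": [
        "The benchmark cells below report a `Performance advantage` figure. As computed in the code, it is the relative speed of the memory-efficient implementation with respect to the standard one (the symbols $t_{\\mathrm{std}}$ and $t_{\\mathrm{mem}}$ are notation introduced here for the two average runtimes):\n",
        "\n",
        "$$\\mathrm{advantage} = \\frac{t_{\\mathrm{std}}}{t_{\\mathrm{mem}}} - 1.$$\n",
        "\n",
        "Positive values mean the memory-efficient implementation was faster. For example, in the recorded 256 x 256 inference run, $0.0018678 / 0.0017407 - 1 \\approx 0.073$, i.e. about 7% faster, while a value of $-0.08$ means the standard implementation took about 8% less time.\n"
      ],
      "metadata": {}
    },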
    {
      "cell_type": "code",
      "source": [
        "# Compare inference performance\n",
        "\n",
        "# The evaluation uses mixed-precision by default (bfloat16 for the inputs and\n",
        "# outputs, and float32 for certain internal representations.)\n",
        "input_dtype = jnp.bfloat16\n",
        "dtype = jnp.float32\n",
        "# Using HIGHEST means that we use full float32 precision. For neural network \n",
        "# training we can often use lax.Precision.DEFAULT instead.\n",
        "precision = lax.Precision.HIGHEST\n",
        "\n",
        "execute_standard_att = True\n",
        "execute_memory_efficient_att = True\n",
        "repeats = 50\n",
        "\n",
        "for i in range(8, 15, 2):\n",
        "  q_size = 2**i\n",
        "  memsize = 2**i\n",
        "  print('\\nAttention size:', q_size, 'x', memsize)\n",
        "  query, key, value = fresh_qkv(q_size, input_dtype), fresh_qkv(memsize, input_dtype), fresh_qkv(memsize, input_dtype)\n",
        "\n",
        "  if execute_standard_att:\n",
        "    _orig_attn = functools.partial(standard_attention,\n",
        "                      precision=precision,\n",
        "                      dtype=dtype)\n",
        "    standard_attn = jax.jit(_orig_attn)\n",
        "\n",
        "    compilation_start = time.time()\n",
        "    compilation_res = standard_attn(query, key, value)\n",
        "    compilation_res.block_until_ready()\n",
        "    print('Standard compilation time:', time.time() - compilation_start)\n",
        "\n",
        "    total_time = 0.0\n",
        "    for _ in range(repeats):\n",
        "      start = time.time()\n",
        "      res_std = standard_attn(query, key, value)\n",
        "      res_std.block_until_ready()\n",
        "      # print('Time of op:', time.time() - start)\n",
        "      total_time += (time.time() - start)\n",
        "    total_time = total_time / repeats\n",
        "    print('Standard attention took:', total_time)\n",
        "\n",
        "  if execute_memory_efficient_att:\n",
        "    _memory_efficient_attn = functools.partial(\n",
        "        mefficient_attention,\n",
        "        precision=precision,\n",
        "        dtype=dtype)\n",
        "    mefficient_attn = jax.jit(_memory_efficient_attn)\n",
        "    \n",
        "    compilation_start = time.time()\n",
        "    compilation_res = mefficient_attn(query, key, value)\n",
        "    compilation_res.block_until_ready()\n",
        "    print('Memory-efficient attention compilation time:', time.time() - compilation_start)\n",
        "\n",
        "    total_time_mem = 0.0\n",
        "    for _ in range(repeats):\n",
        "      start = time.time()\n",
        "      res = mefficient_attn(query, key, value)\n",
        "      res.block_until_ready()\n",
        "      total_time_mem += (time.time() - start)\n",
        "    total_time_mem = total_time_mem / repeats\n",
        "    print('Memory-efficient attention took:', total_time_mem)\n",
        "\n",
        "  if execute_standard_att and execute_memory_efficient_att:\n",
        "    print('Performance advantage:', (total_time / total_time_mem) - 1.0)\n",
        "    diff = res - res_std\n",
        "    print('avg difference', jnp.average(jnp.abs(diff)))\n",
        "    print('max difference', jnp.max(jnp.abs(diff)))\n",
        "    np.testing.assert_allclose(res.astype(jnp.float32), res_std.astype(jnp.float32), atol=1e-2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2PUZWqnInd3k",
        "outputId": "c0bcf1e4-f7ec-4ef5-d859-d2b156241aab"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Attention size: 256 x 256\n",
            "Standard compilation time: 0.10895419120788574\n",
            "Standard attention took: 0.0018677806854248047\n",
            "Memory-efficient attention compilation time: 0.31549072265625\n",
            "Memory-efficient attention took: 0.001740708351135254\n",
            "Performance advantage: 0.07300035885200229\n",
            "avg difference 7.912317e-09\n",
            "max difference 8.940697e-08\n",
            "\n",
            "Attention size: 1024 x 1024\n",
            "Standard compilation time: 0.18660283088684082\n",
            "Standard attention took: 0.0019140434265136719\n",
            "Memory-efficient attention compilation time: 0.3578367233276367\n",
            "Memory-efficient attention took: 0.0018649005889892578\n",
            "Performance advantage: 0.02635145155434193\n",
            "avg difference 5.8978773e-09\n",
            "max difference 8.940697e-08\n",
            "\n",
            "Attention size: 4096 x 4096\n",
            "Standard compilation time: 0.5202584266662598\n",
            "Standard attention took: 0.0029724836349487305\n",
            "Memory-efficient attention compilation time: 0.4849257469177246\n",
            "Memory-efficient attention took: 0.0031867408752441406\n",
            "Performance advantage: -0.06723396996594388\n",
            "avg difference 5.1832956e-09\n",
            "max difference 8.940697e-08\n",
            "\n",
            "Attention size: 16384 x 16384\n",
            "Standard compilation time: 1.089770793914795\n",
            "Standard attention took: 0.023688259124755858\n",
            "Memory-efficient attention compilation time: 0.6026222705841064\n",
            "Memory-efficient attention took: 0.025738215446472167\n",
            "Performance advantage: -0.079646404622714\n",
            "avg difference 2.0201878e-08\n",
            "max difference 1.6391277e-07\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Compare differentiation performance\n",
        "\n",
        "repeats = 50\n",
        "\n",
        "execute_standard_att = True\n",
        "execute_memory_efficient_att = True\n",
        "\n",
        "input_dtype = jnp.bfloat16\n",
        "dtype = jnp.float32\n",
        "precision = lax.Precision.HIGHEST\n",
        "\n",
        "\n",
        "def loss_ckpt(query, key, value):\n",
        "  return jnp.sum(mefficient_attention(query, key, value, precision=precision))\n",
        "\n",
        "def loss_simp(query, key, value):\n",
        "  return jnp.sum(standard_attention(query, key, value, precision=precision))\n",
        "\n",
        "diff_mefficient_attention = jax.jit(jax.grad(loss_ckpt, argnums=[0,1,2]))\n",
        "diff_attention_simp = jax.jit(jax.grad(loss_simp, argnums=[0,1,2]))\n",
        "\n",
        "for i in range(8, 15, 1):\n",
        "  q_size = 2**i\n",
        "  memsize = 2**i\n",
        "  print('\\nAttention size:', q_size, 'x', memsize)\n",
        "\n",
        "  query = fresh_qkv(q_size, input_dtype)\n",
        "  key = fresh_qkv(memsize, input_dtype)\n",
        "  value = fresh_qkv(memsize, input_dtype)\n",
        "\n",
        "  if execute_standard_att:\n",
        "    compilation_start = time.time()\n",
        "    _comp_res = diff_attention_simp(query, key, value)\n",
        "    for t in _comp_res:\n",
        "      t.block_until_ready()\n",
        "    print('Diff simp compilation time:', time.time() - compilation_start)\n",
        "\n",
        "    total_time_simp = 0.0\n",
        "    for _ in range(repeats):\n",
        "      start = time.time()\n",
        "      res_std = diff_attention_simp(query, key, value)\n",
        "      for t in res_std:\n",
        "        t.block_until_ready()\n",
        "      total_time_simp += (time.time() - start)\n",
        "    total_time_simp = total_time_simp / repeats\n",
        "    print('Standard attention took:', total_time_simp)\n",
        "\n",
        "  if execute_memory_efficient_att:\n",
        "    compilation_start = time.time()\n",
        "    _comp_res = diff_mefficient_attention(query, key, value)\n",
        "    for t in _comp_res:\n",
        "      t.block_until_ready()\n",
        "    print('Diff mem ckpt compilation time:', time.time() - compilation_start)\n",
        "\n",
        "    total_time_mem = 0.0\n",
        "    for _ in range(repeats):\n",
        "      start = time.time()\n",
        "      res = diff_mefficient_attention(query, key, value)\n",
        "      for t in res:\n",
        "        t.block_until_ready()\n",
        "      total_time_mem += (time.time() - start)\n",
        "    total_time_mem = total_time_mem / repeats\n",
        "    print('Memory-efficient attention took:', total_time_mem)\n",
        "\n",
        "  if execute_standard_att and execute_memory_efficient_att:\n",
        "    print('Performance advantage:', (total_time_simp / total_time_mem) - 1.0)\n",
        "    diff = res[0] - res_std[0]\n",
        "    print('avg difference', jnp.average(jnp.abs(diff)))\n",
        "    print('max difference', jnp.max(jnp.abs(diff)))\n",
        "    for tuple_idx in range(3):\n",
        "      np.testing.assert_allclose(\n",
        "          res[tuple_idx].astype(jnp.float32),\n",
        "          res_std[tuple_idx].astype(jnp.float32), atol=1e-2, rtol=1e-2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NJ6_WUtsngOA",
        "outputId": "e5637671-adcd-4128-fc94-29b2b8b71348"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Attention size: 256 x 256\n",
            "Diff simp compilation time: 0.2330009937286377\n",
            "Standard attention took: 0.0023757553100585936\n",
            "Diff mem ckpt compilation time: 1.592576503753662\n",
            "Memory-efficient attention took: 0.0025052547454833983\n",
            "Performance advantage: -0.051691124688326706\n",
            "avg difference 3.00352e-08\n",
            "max difference 0.000488281\n",
            "\n",
            "Attention size: 512 x 512\n",
            "Diff simp compilation time: 0.2673332691192627\n",
            "Standard attention took: 0.003413748741149902\n",
            "Diff mem ckpt compilation time: 1.7921810150146484\n",
            "Memory-efficient attention took: 0.002515106201171875\n",
            "Performance advantage: 0.3572980495055511\n",
            "avg difference 9.66338e-13\n",
            "max difference 2.98023e-08\n",
            "\n",
            "Attention size: 1024 x 1024\n",
            "Diff simp compilation time: 0.37073850631713867\n",
            "Standard attention took: 0.002552280426025391\n",
            "Diff mem ckpt compilation time: 1.660691261291504\n",
            "Memory-efficient attention took: 0.002996225357055664\n",
            "Performance advantage: -0.14816807086451256\n",
            "avg difference 9.38599e-10\n",
            "max difference 6.10352e-05\n",
            "\n",
            "Attention size: 2048 x 2048\n",
            "Diff simp compilation time: 0.7198879718780518\n",
            "Standard attention took: 0.003253016471862793\n",
            "Diff mem ckpt compilation time: 2.225461721420288\n",
            "Memory-efficient attention took: 0.004697685241699219\n",
            "Performance advantage: -0.3075277919884366\n",
            "avg difference 3.0268e-09\n",
            "max difference 0.00012207\n",
            "\n",
            "Attention size: 4096 x 4096\n",
            "Diff simp compilation time: 1.1984167098999023\n",
            "Standard attention took: 0.006762046813964844\n",
            "Diff mem ckpt compilation time: 2.3629977703094482\n",
            "Memory-efficient attention took: 0.010511817932128907\n",
            "Performance advantage: -0.35671956481504996\n",
            "avg difference 3.55067e-09\n",
            "max difference 0.000244141\n",
            "\n",
            "Attention size: 8192 x 8192\n",
            "Diff simp compilation time: 2.0245492458343506\n",
            "Standard attention took: 0.019772610664367675\n",
            "Diff mem ckpt compilation time: 2.415445327758789\n",
            "Memory-efficient attention took: 0.03363493919372559\n",
            "Performance advantage: -0.4121407340597736\n",
            "avg difference 2.31666e-08\n",
            "max difference 0.000244141\n",
            "\n",
            "Attention size: 16384 x 16384\n",
            "Diff simp compilation time: 2.0951600074768066\n",
            "Standard attention took: 0.0703316593170166\n",
            "Diff mem ckpt compilation time: 2.444955587387085\n",
            "Memory-efficient attention took: 0.12465510368347169\n",
            "Performance advantage: -0.4357899737855495\n",
            "avg difference 1.9907e-08\n",
            "max difference 0.000244141\n"
          ]
        }
      ]
    }
  ]
}